{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate the Fine-tuned Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the previous sections, we prepared the dataset and fine-tuned the model. In this tutorial, we will go through how to evaluate the model with the test dataset we constructed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -U datasets pytrec_eval FlagEmbedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first load data from the files we processed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Each JSON Lines file is loaded on its own; the loaded records are read from the \"train\" split.\n",
    "queries = load_dataset(\"json\", data_files=\"ft_data/test_queries.jsonl\")[\"train\"]\n",
    "corpus = load_dataset(\"json\", data_files=\"ft_data/corpus.jsonl\")[\"train\"]\n",
    "qrels = load_dataset(\"json\", data_files=\"ft_data/test_qrels.jsonl\")[\"train\"]\n",
    "\n",
    "queries_text = queries[\"text\"]\n",
    "# The flattening assumes every `text` entry of the corpus is a list of passages.\n",
    "corpus_text = [text for sub in corpus[\"text\"] for text in sub]\n",
    "\n",
    "# `search` maps position i of `corpus_text` back to `corpus[\"id\"][i]`, so the two must line up.\n",
    "assert len(corpus_text) == len(corpus), \"each corpus record must contain exactly one passage\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group the relevance judgments by query: qid -> {docid: relevance}.\n",
    "qrels_dict = {}\n",
    "for record in qrels:\n",
    "    judged_docs = qrels_dict.setdefault(record['qid'], {})\n",
    "    judged_docs[record['docid']] = record['relevance']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Search"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we prepare a function that encodes the text into embeddings and retrieves the top-ranked passages for each query:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import faiss\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def search(\n",
    "    model,\n",
    "    queries_text,\n",
    "    corpus_text,\n",
    "    query_ids=None,\n",
    "    corpus_ids=None,\n",
    "    top_k=100,\n",
    "    batch_size=32,\n",
    "):\n",
    "    \"\"\"Encode queries and corpus with `model` and retrieve the top passages per query.\n",
    "\n",
    "    Args:\n",
    "        model: embedder exposing `encode_queries` and `encode_corpus`.\n",
    "        queries_text: query strings, in the same order as `query_ids`.\n",
    "        corpus_text: passage strings, in the same order as `corpus_ids`.\n",
    "        query_ids: ids of the queries; defaults to the notebook-level `queries[\"id\"]`.\n",
    "        corpus_ids: ids of the passages; defaults to the notebook-level `corpus[\"id\"]`.\n",
    "        top_k: number of passages retrieved for each query.\n",
    "        batch_size: number of queries passed to the index per search call.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each query id to `{corpus id: score}` for its retrieved passages.\n",
    "    \"\"\"\n",
    "    # Look the id columns up once instead of on every iteration of the result loops.\n",
    "    if query_ids is None:\n",
    "        query_ids = list(queries[\"id\"])\n",
    "    if corpus_ids is None:\n",
    "        corpus_ids = list(corpus[\"id\"])\n",
    "\n",
    "    queries_embeddings = model.encode_queries(queries_text)\n",
    "    corpus_embeddings = model.encode_corpus(corpus_text)\n",
    "\n",
    "    # create and store the embeddings in a Faiss index\n",
    "    dim = corpus_embeddings.shape[-1]\n",
    "    index = faiss.index_factory(dim, 'Flat', faiss.METRIC_INNER_PRODUCT)\n",
    "    corpus_embeddings = corpus_embeddings.astype(np.float32)\n",
    "    index.train(corpus_embeddings)\n",
    "    index.add(corpus_embeddings)\n",
    "\n",
    "    query_size = len(queries_embeddings)\n",
    "\n",
    "    all_scores = []\n",
    "    all_indices = []\n",
    "\n",
    "    # search the top_k answers for all the queries, one batch at a time\n",
    "    for start in tqdm(range(0, query_size, batch_size), desc=\"Searching\"):\n",
    "        end = min(start + batch_size, query_size)\n",
    "        query_embedding = queries_embeddings[start:end]\n",
    "        score, indice = index.search(query_embedding.astype(np.float32), k=top_k)\n",
    "        all_scores.append(score)\n",
    "        all_indices.append(indice)\n",
    "\n",
    "    all_scores = np.concatenate(all_scores, axis=0)\n",
    "    all_indices = np.concatenate(all_indices, axis=0)\n",
    "\n",
    "    # store the results into the format for evaluation\n",
    "    results = {}\n",
    "    for idx, (scores, indices) in enumerate(zip(all_scores, all_indices)):\n",
    "        # `doc_idx` (not `index`) so the Faiss index object above is not shadowed.\n",
    "        # Entries equal to -1 carry no hit and are skipped.\n",
    "        results[query_ids[idx]] = {\n",
    "            corpus_ids[doc_idx]: float(score)\n",
    "            for score, doc_idx in zip(scores, indices)\n",
    "            if doc_idx != -1\n",
    "        }\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from FlagEmbedding.abc.evaluation.utils import evaluate_metrics, evaluate_mrr\n",
    "from FlagEmbedding import FlagModel\n",
    "\n",
    "k_values = [10,100]\n",
    "\n",
    "raw_name = \"BAAI/bge-large-en-v1.5\"\n",
    "finetuned_path = \"test_encoder_only_base_bge-large-en-v1.5\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result for the original model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "pre tokenize: 100%|██████████| 3/3 [00:00<00:00, 129.75it/s]\n",
      "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "Inference Embeddings: 100%|██████████| 3/3 [00:00<00:00, 11.08it/s]\n",
      "pre tokenize: 100%|██████████| 28/28 [00:00<00:00, 164.29it/s]\n",
      "Inference Embeddings: 100%|██████████| 28/28 [00:04<00:00,  6.09it/s]\n",
      "Searching: 100%|██████████| 22/22 [00:08<00:00,  2.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "defaultdict(<class 'list'>, {'NDCG@10': 0.70405, 'NDCG@100': 0.73528})\n",
      "defaultdict(<class 'list'>, {'MAP@10': 0.666, 'MAP@100': 0.67213})\n",
      "defaultdict(<class 'list'>, {'Recall@10': 0.82286, 'Recall@100': 0.97286})\n",
      "defaultdict(<class 'list'>, {'P@10': 0.08229, 'P@100': 0.00973})\n",
      "defaultdict(<class 'list'>, {'MRR@10': 0.666, 'MRR@100': 0.67213})\n"
     ]
    }
   ],
   "source": [
    "# Baseline: load the original checkpoint and run retrieval with it.\n",
    "raw_model = FlagModel(\n",
    "    raw_name,\n",
    "    query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n",
    "    devices=[0],\n",
    "    use_fp16=False,\n",
    ")\n",
    "results = search(raw_model, queries_text, corpus_text)\n",
    "\n",
    "eval_res = evaluate_metrics(qrels_dict, results, k_values)\n",
    "mrr = evaluate_mrr(qrels_dict, results, k_values)\n",
    "\n",
    "# Print every metric dict returned by `evaluate_metrics`, then MRR last.\n",
    "for metric_scores in (*eval_res, mrr):\n",
    "    print(metric_scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then the result for the model after fine-tuning:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "pre tokenize: 100%|██████████| 3/3 [00:00<00:00, 164.72it/s]\n",
      "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "Inference Embeddings: 100%|██████████| 3/3 [00:00<00:00,  9.45it/s]\n",
      "pre tokenize: 100%|██████████| 28/28 [00:00<00:00, 160.19it/s]\n",
      "Inference Embeddings: 100%|██████████| 28/28 [00:04<00:00,  6.06it/s]\n",
      "Searching: 100%|██████████| 22/22 [00:07<00:00,  2.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "defaultdict(<class 'list'>, {'NDCG@10': 0.84392, 'NDCG@100': 0.85792})\n",
      "defaultdict(<class 'list'>, {'MAP@10': 0.81562, 'MAP@100': 0.81875})\n",
      "defaultdict(<class 'list'>, {'Recall@10': 0.93143, 'Recall@100': 0.99429})\n",
      "defaultdict(<class 'list'>, {'P@10': 0.09314, 'P@100': 0.00994})\n",
      "defaultdict(<class 'list'>, {'MRR@10': 0.81562, 'MRR@100': 0.81875})\n"
     ]
    }
   ],
   "source": [
    "# Fine-tuned model: same settings as the baseline run, only the checkpoint differs.\n",
    "ft_model = FlagModel(\n",
    "    finetuned_path,\n",
    "    query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n",
    "    devices=[0],\n",
    "    use_fp16=False,\n",
    ")\n",
    "results = search(ft_model, queries_text, corpus_text)\n",
    "\n",
    "eval_res = evaluate_metrics(qrels_dict, results, k_values)\n",
    "mrr = evaluate_mrr(qrels_dict, results, k_values)\n",
    "\n",
    "# Print every metric dict returned by `evaluate_metrics`, then MRR last.\n",
    "for metric_scores in (*eval_res, mrr):\n",
    "    print(metric_scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see an obvious improvement in all the metrics."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ft",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
